/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <exception>

#include "dataset/api/de_pipeline.h"
#include "dataset/kernels/no_op.h"
#include "dataset/kernels/data/one_hot_op.h"
#include "dataset/kernels/image/center_crop_op.h"
#include "dataset/kernels/image/change_mode_op.h"
#include "dataset/kernels/image/cut_out_op.h"
#include "dataset/kernels/image/decode_op.h"
#include "dataset/kernels/image/distort_bounding_box_crop_op.h"
#include "dataset/kernels/image/hwc_to_chw_op.h"
#include "dataset/kernels/image/image_utils.h"
#include "dataset/kernels/image/normalize_op.h"
#include "dataset/kernels/image/pad_op.h"
#include "dataset/kernels/image/random_color_adjust_op.h"
#include "dataset/kernels/image/random_crop_decode_resize_op.h"
#include "dataset/kernels/image/random_crop_and_resize_op.h"
#include "dataset/kernels/image/random_crop_op.h"
#include "dataset/kernels/image/random_horizontal_flip_op.h"
#include "dataset/kernels/image/random_resize_op.h"
#include "dataset/kernels/image/random_rotation_op.h"
#include "dataset/kernels/image/random_vertical_flip_op.h"
#include "dataset/kernels/image/rescale_op.h"
#include "dataset/kernels/image/resize_bilinear_op.h"
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#ifdef ENABLE_MINDRECORD
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#endif
#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/stl_bind.h"

namespace py = pybind11;

namespace mindspore {
namespace dataset {
#define THROW_IF_ERROR(s)                                      \
  do {                                                         \
    Status rc = std::move(s);                                  \
    if (rc.IsError()) throw std::runtime_error(rc.ToString()); \
  } while (false)

void bindDEPipeline(py::module *m) {
  (void)py::class_<DEPipeline>(*m, "DEPipeline")
    .def(py::init<>())
    .def(
      "AddNodeToTree",
      [](DEPipeline &de, const OpName &op_name, const py::dict &args) {
        DsOpPtr op;
        THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &op));
        return op;
      },
      py::return_value_policy::reference)
    .def_static("AddChildToParentNode",
                [](const DsOpPtr &child_op, const DsOpPtr &parent_op) {
                  THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op));
                })
    .def("AssignRootNode",
         [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
    .def("SetBatchParameters",
         [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
    .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
    .def("GetNextAsMap",
         [](DEPipeline &de) {
           py::dict out;
           THROW_IF_ERROR(de.GetNextAsMap(&out));
           return out;
         })
    .def("GetNextAsList",
         [](DEPipeline &de) {
           py::list out;
           THROW_IF_ERROR(de.GetNextAsList(&out));
           return out;
         })
    .def("GetOutputShapes",
         [](DEPipeline &de) {
           py::list out;
           THROW_IF_ERROR(de.GetOutputShapes(&out));
           return out;
         })
    .def("GetOutputTypes",
         [](DEPipeline &de) {
           py::list out;
           THROW_IF_ERROR(de.GetOutputTypes(&out));
           return out;
         })
    .def("GetDatasetSize", &DEPipeline::GetDatasetSize)
    .def("GetBatchSize", &DEPipeline::GetBatchSize)
    .def("GetNumClasses", &DEPipeline::GetNumClasses)
    .def("GetRepeatCount", &DEPipeline::GetRepeatCount);
}
void bindDatasetOps(py::module *m) {
  (void)py::class_<TFReaderOp, DatasetOp, std::shared_ptr<TFReaderOp>>(*m, "TFReaderOp")
    .def_static("get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) {
      int64_t count = 0;
      std::vector<std::string> filenames;
      for (auto l : files) {
        !l.is_none() ? filenames.push_back(py::str(l)) : (void)filenames.emplace_back("");
      }
      THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate));
      return count;
    });

  (void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
    .def_static("get_num_rows", [](const std::string &dir, int64_t numSamples, bool isCifar10) {
      int64_t count = 0;
      THROW_IF_ERROR(CifarOp::CountTotalRows(dir, numSamples, isCifar10, &count));
      return count;
    });

  (void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp")
    .def_static("get_num_rows_and_classes", [](const std::string &path, int64_t numSamples) {
      int64_t count = 0, num_classes = 0;
      THROW_IF_ERROR(
        ImageFolderOp::CountRowsAndClasses(path, numSamples, std::set<std::string>{}, &count, &num_classes));
      return py::make_tuple(count, num_classes);
    });

#ifdef ENABLE_MINDRECORD
  (void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
    .def_static("get_num_rows", [](const std::string &path) {
      int64_t count = 0;
      THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, &count));
      return count;
    });
#endif

  (void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
    .def_static("get_num_rows_and_classes",
                [](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
                  int64_t count = 0, num_classes = 0;
                  THROW_IF_ERROR(ManifestOp::CountTotalRows(file, numSamples, dict, usage, &count, &num_classes));
                  return py::make_tuple(count, num_classes);
                })
    .def_static("get_class_indexing",
                [](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) {
                  std::map<std::string, int32_t> output_class_indexing;
                  THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, numSamples, dict, usage, &output_class_indexing));
                  return output_class_indexing;
                });

  (void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
    .def_static("get_num_rows", [](const std::string &dir, int64_t numSamples) {
      int64_t count = 0;
      THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count));
      return count;
    });
}
void bindTensor(py::module *m) {
  (void)py::class_<GlobalContext>(*m, "GlobalContext")
    .def_static("config_manager", &GlobalContext::config_manager, py::return_value_policy::reference);

  (void)py::class_<ConfigManager, std::shared_ptr<ConfigManager>>(*m, "ConfigManager")
    .def("__str__", &ConfigManager::ToString)
    .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer)
    .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers)
    .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
    .def("set_op_connector_size", &ConfigManager::set_op_connector_size)
    .def("set_seed", &ConfigManager::set_seed)
    .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer)
    .def("get_num_parallel_workers", &ConfigManager::num_parallel_workers)
    .def("get_worker_connector_size", &ConfigManager::worker_connector_size)
    .def("get_op_connector_size", &ConfigManager::op_connector_size)
    .def("get_seed", &ConfigManager::seed)
    .def("load", [](ConfigManager &c, std::string s) { (void)c.LoadFile(s); });

  (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())
    .def(py::init([](py::array arr) {
      std::shared_ptr<Tensor> out;
      THROW_IF_ERROR(Tensor::CreateTensor(&out, arr));
      return out;
    }))
    .def_buffer([](Tensor &tensor) {
      py::buffer_info info;
      THROW_IF_ERROR(Tensor::GetBufferInfo(tensor, &info));
      return info;
    })
    .def("__str__", &Tensor::ToString)
    .def("shape", &Tensor::shape)
    .def("type", &Tensor::type)
    .def("as_array", [](py::object &t) {
      auto &tensor = py::cast<Tensor &>(t);
      py::buffer_info info;
      THROW_IF_ERROR(Tensor::GetBufferInfo(tensor, &info));
      return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, t);
    });

  (void)py::class_<TensorShape>(*m, "TensorShape")
    .def(py::init<py::list>())
    .def("__str__", &TensorShape::ToString)
    .def("as_list", &TensorShape::AsPyList)
    .def("is_known", &TensorShape::known);

  (void)py::class_<DataType>(*m, "DataType")
    .def(py::init<std::string>())
    .def(py::self == py::self)
    .def("__str__", &DataType::ToString);
}

void bindTensorOps1(py::module *m) {
  (void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp");

  (void)py::class_<NormalizeOp, TensorOp, std::shared_ptr<NormalizeOp>>(
    *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.")
    .def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"),
         py::arg("stdR"), py::arg("stdG"), py::arg("stdB"));

  (void)py::class_<RescaleOp, TensorOp, std::shared_ptr<RescaleOp>>(
    *m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.")
    .def(py::init<float, float>(), py::arg("rescale"), py::arg("shift"));

  (void)py::class_<CenterCropOp, TensorOp, std::shared_ptr<CenterCropOp>>(
    *m, "CenterCropOp", "Tensor operation to crop and image in the middle. Takes height and width (optional)")
    .def(py::init<int32_t, int32_t>(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth);

  (void)py::class_<ResizeOp, TensorOp, std::shared_ptr<ResizeOp>>(
    *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode")
    .def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"),
         py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation);

  (void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
    *m, "ResizeBilinearOp",
    "Tensor operation to resize an image using "
    "Bilinear mode. Takes height and width.")
    .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"), py::arg("targetWidth") = ResizeBilinearOp::kDefWidth);

  (void)py::class_<DecodeOp, TensorOp, std::shared_ptr<DecodeOp>>(*m, "DecodeOp",
                                                                  "Tensor operation to decode a jpg image")
    .def(py::init<>())
    .def(py::init<bool>(), py::arg("rgb_format") = DecodeOp::kDefRgbFormat);

  (void)py::class_<RandomHorizontalFlipOp, TensorOp, std::shared_ptr<RandomHorizontalFlipOp>>(
    *m, "RandomHorizontalFlipOp", "Tensor operation to randomly flip an image horizontally.")
    .def(py::init<float>(), py::arg("probability") = RandomHorizontalFlipOp::kDefProbability);
}

void bindTensorOps2(py::module *m) {
  (void)py::class_<RandomVerticalFlipOp, TensorOp, std::shared_ptr<RandomVerticalFlipOp>>(
    *m, "RandomVerticalFlipOp", "Tensor operation to randomly flip an image vertically.")
    .def(py::init<float>(), py::arg("probability") = RandomVerticalFlipOp::kDefProbability);

  (void)py::class_<RandomCropOp, TensorOp, std::shared_ptr<RandomCropOp>>(*m, "RandomCropOp",
                                                                          "Gives random crop of specified size "
                                                                          "Takes crop size")
    .def(py::init<int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, BorderType, bool, uint8_t, uint8_t, uint8_t>(),
         py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropOp::kDefPadTop,
         py::arg("padBottom") = RandomCropOp::kDefPadBottom, py::arg("padLeft") = RandomCropOp::kDefPadLeft,
         py::arg("padRight") = RandomCropOp::kDefPadRight, py::arg("borderType") = RandomCropOp::kDefBorderType,
         py::arg("padIfNeeded") = RandomCropOp::kDefPadIfNeeded, py::arg("fillR") = RandomCropOp::kDefFillR,
         py::arg("fillG") = RandomCropOp::kDefFillG, py::arg("fillB") = RandomCropOp::kDefFillB);
  (void)py::class_<HwcToChwOp, TensorOp, std::shared_ptr<HwcToChwOp>>(*m, "ChannelSwapOp").def(py::init<>());

  (void)py::class_<ChangeModeOp, TensorOp, std::shared_ptr<ChangeModeOp>>(
    *m, "ChangeModeOp", "Tensor operation to change colors from BGR to RGB")
    .def(py::init<>());

  (void)py::class_<OneHotOp, TensorOp, std::shared_ptr<OneHotOp>>(
    *m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.")
    .def(py::init<int32_t>());

  (void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
    *m, "RandomRotationOp",
    "Tensor operation to apply RandomRotation."
    "Takes a range for degrees and "
    "optional parameters for rotation center and image expand")
    .def(py::init<float, float, float, float, InterpolationMode, bool, uint8_t, uint8_t, uint8_t>(),
         py::arg("startDegree"), py::arg("endDegree"), py::arg("centerX") = RandomRotationOp::kDefCenterX,
         py::arg("centerY") = RandomRotationOp::kDefCenterY,
         py::arg("interpolation") = RandomRotationOp::kDefInterpolation,
         py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR,
         py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB);
}

void bindTensorOps3(py::module *m) {
  (void)py::class_<RandomCropAndResizeOp, TensorOp, std::shared_ptr<RandomCropAndResizeOp>>(
    *m, "RandomCropAndResizeOp",
    "Tensor operation to randomly crop an image and resize to a given size."
    "Takes output height and width and"
    "optional parameters for lower and upper bound for aspect ratio (h/w) and scale,"
    "interpolation mode, and max attempts to crop")
    .def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
         py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeOp::kDefScaleLb,
         py::arg("scaleUb") = RandomCropAndResizeOp::kDefScaleUb,
         py::arg("aspectLb") = RandomCropAndResizeOp::kDefAspectLb,
         py::arg("aspectUb") = RandomCropAndResizeOp::kDefAspectUb,
         py::arg("interpolation") = RandomCropAndResizeOp::kDefInterpolation,
         py::arg("maxIter") = RandomCropAndResizeOp::kDefMaxIter);

  (void)py::class_<RandomColorAdjustOp, TensorOp, std::shared_ptr<RandomColorAdjustOp>>(
    *m, "RandomColorAdjustOp",
    "Tensor operation to adjust an image's color randomly."
    "Takes range for brightness, contrast, saturation, hue and")
    .def(py::init<float, float, float, float, float, float, float, float>(), py::arg("bright_factor_start"),
         py::arg("bright_factor_end"), py::arg("contrast_factor_start"), py::arg("contrast_factor_end"),
         py::arg("saturation_factor_start"), py::arg("saturation_factor_end"), py::arg("hue_factor_start"),
         py::arg("hue_factor_end"));

  (void)py::class_<RandomResizeOp, TensorOp, std::shared_ptr<RandomResizeOp>>(
    *m, "RandomResizeOp",
    "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.")
    .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"),
         py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth);

  (void)py::class_<CutOutOp, TensorOp, std::shared_ptr<CutOutOp>>(
    *m, "CutOutOp", "Tensor operation to randomly erase a portion of the image. Takes height and width.")
    .def(py::init<int32_t, int32_t, int32_t, bool, uint8_t, uint8_t, uint8_t>(), py::arg("boxHeight"),
         py::arg("boxWidth"), py::arg("numPatches"), py::arg("randomColor") = CutOutOp::kDefRandomColor,
         py::arg("fillR") = CutOutOp::kDefFillR, py::arg("fillG") = CutOutOp::kDefFillG,
         py::arg("fillB") = CutOutOp::kDefFillB);
}

void bindTensorOps4(py::module *m) {
  (void)py::class_<DistortBoundingBoxCropOp, TensorOp, std::shared_ptr<DistortBoundingBoxCropOp>>(
    *m, "DistortBoundingBoxCropOp",
    "Tensor operator to crop an image randomly as long as the cropped image has sufficient "
    "overlap with any one bounding box associated with original image"
    "Takes aspect ratio of the generated crop box, the intersection ratio of crop box and bounding box,"
    "crop ratio lower and upper bounds"
    "Optional parameters: number of attempts for crop, number of attempts of crop box generation")
    .def(py::init<float, float, float, float, int32_t, int32_t>(), py::arg("aspect_ratio"), py::arg("intersect_ratio"),
         py::arg("crop_ratio_lower_bound"), py::arg("crop_ratio_upper_bound"),
         py::arg("max_attempts") = DistortBoundingBoxCropOp::kDefMaxAttempts,
         py::arg("box_gen_attempts") = DistortBoundingBoxCropOp::kDefBoxGenAttempts);

  (void)py::class_<TypeCastOp, TensorOp, std::shared_ptr<TypeCastOp>>(
    *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.")
    .def(py::init<DataType>(), py::arg("data_type"))
    .def(py::init<std::string>(), py::arg("data_type"));

  (void)py::class_<NoOp, TensorOp, std::shared_ptr<NoOp>>(*m, "NoOp",
                                                          "TensorOp that does nothing, for testing purposes only.")
    .def(py::init<>());

  (void)py::class_<ToFloat16Op, TensorOp, std::shared_ptr<ToFloat16Op>>(
    *m, "ToFloat16Op", py::dynamic_attr(), "Tensor operator to type cast float32 data to a float16 type.")
    .def(py::init<>());

  (void)py::class_<RandomCropDecodeResizeOp, TensorOp, std::shared_ptr<RandomCropDecodeResizeOp>>(
    *m, "RandomCropDecodeResizeOp", "equivalent to RandomCropAndResize but crops before decoding")
    .def(py::init<int32_t, int32_t, float, float, float, float, InterpolationMode, int32_t>(), py::arg("targetHeight"),
         py::arg("targetWidth"), py::arg("scaleLb") = RandomCropDecodeResizeOp::kDefScaleLb,
         py::arg("scaleUb") = RandomCropDecodeResizeOp::kDefScaleUb,
         py::arg("aspectLb") = RandomCropDecodeResizeOp::kDefAspectLb,
         py::arg("aspectUb") = RandomCropDecodeResizeOp::kDefAspectUb,
         py::arg("interpolation") = RandomCropDecodeResizeOp::kDefInterpolation,
         py::arg("maxIter") = RandomCropDecodeResizeOp::kDefMaxIter);

  (void)py::class_<PadOp, TensorOp, std::shared_ptr<PadOp>>(
    *m, "PadOp",
    "Pads image with specified color, default black, "
    "Takes amount to pad for top, bottom, left, right of image, boarder type and color")
    .def(py::init<int32_t, int32_t, int32_t, int32_t, BorderType, uint8_t, uint8_t, uint8_t>(), py::arg("padTop"),
         py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType,
         py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB);
}

void bindSamplerOps(py::module *m) {
  (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler");

  (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler")
    .def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"),
         py::arg("seed"));

  (void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
    .def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle"));

  (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
    .def(py::init<bool, int64_t>(), py::arg("replacement"), py::arg("numSamples"))
    .def(py::init<bool>(), py::arg("replacement"));

  (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler")
    .def(py::init<>());
  (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler")
    .def(py::init<std::vector<int64_t>>(), py::arg("indices"));

  (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
    .def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
         py::arg("replacement"));
}

void bindInfoObjects(py::module *m) {
  (void)py::class_<BatchOp::CBatchInfo>(*m, "CBatchInfo")
    .def(py::init<int32_t, int32_t, int32_t>())
    .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num)
    .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
}

// This is where we externalize the C logic as python modules
PYBIND11_MODULE(_c_dataengine, m) {
  m.doc() = "pybind11 for _c_dataengine";
  (void)py::class_<DatasetOp, std::shared_ptr<DatasetOp>>(m, "DatasetOp");

  (void)py::enum_<OpName>(m, "OpName", py::arithmetic())
    .value("STORAGE", OpName::kStorage)
    .value("SHUFFLE", OpName::kShuffle)
    .value("BATCH", OpName::kBatch)
#ifdef ENABLE_MINDRECORD
    .value("MINDRECORD", OpName::kMindrecord)
#endif
    .value("CACHE", OpName::kCache)
    .value("REPEAT", OpName::kRepeat)
    .value("TAKE", OpName::kTake)
    .value("ZIP", OpName::kZip)
    .value("MAP", OpName::kMap)
    .value("FILTER", OpName::kFilter)
    .value("DEVICEQUEUE", OpName::kDeviceQueue)
    .value("GENERATOR", OpName::kGenerator)
    .export_values()
    .value("RENAME", OpName::kRename)
    .value("TFREADER", OpName::kTfReader)
    .value("PROJECT", OpName::kProject)
    .value("IMAGEFOLDER", OpName::kImageFolder)
    .value("MNIST", OpName::kMnist)
    .value("MANIFEST", OpName::kManifest)
    .value("VOC", OpName::kVoc)
    .value("CIFAR10", OpName::kCifar10)
    .value("CIFAR100", OpName::kCifar100)
    .value("CELEBA", OpName::kCelebA);

  (void)py::enum_<InterpolationMode>(m, "InterpolationMode", py::arithmetic())
    .value("DE_INTER_LINEAR", InterpolationMode::kLinear)
    .value("DE_INTER_CUBIC", InterpolationMode::kCubic)
    .value("DE_INTER_AREA", InterpolationMode::kArea)
    .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour)
    .export_values();

  (void)py::enum_<BorderType>(m, "BorderType", py::arithmetic())
    .value("DE_BORDER_CONSTANT", BorderType::kConstant)
    .value("DE_BORDER_EDGE", BorderType::kEdge)
    .value("DE_BORDER_REFLECT", BorderType::kReflect)
    .value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric)
    .export_values();
  bindDEPipeline(&m);
  bindTensor(&m);
  bindTensorOps1(&m);
  bindTensorOps2(&m);
  bindTensorOps3(&m);
  bindTensorOps4(&m);
  bindSamplerOps(&m);
  bindDatasetOps(&m);
  bindInfoObjects(&m);
}
}  // namespace dataset
}  // namespace mindspore
